package com.diven.common.tree.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;

import com.diven.common.tree.common.Constant;
import com.diven.common.tree.model.BinPoint;
import com.diven.common.tree.model.BinPoints;
import com.diven.common.tree.model.ClassifierTree;
import com.diven.common.tree.model.FeatureBinsLabel;
import com.diven.common.tree.model.NominalPoint;
import com.diven.common.tree.model.NumericPoint;
import com.diven.common.tree.model.RuleInfo;
import com.diven.common.tree.weka.WekaREPTree;

import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class TreeUtil {
	
	/**
	 * Builds a tree node that carries no training data.
	 * <p>
	 * The node receives an empty copy of the dataset header (zero rows) and zero-filled
	 * distribution and class-probability arrays with one slot per class value.
	 *
	 * @param instances dataset whose header and class count are used
	 * @return a new, empty node
	 */
	public static ClassifierTree getEmptyClassifierTree(Instances instances) {
		ClassifierTree emptyNode = new ClassifierTree();
		// Header-only copy: same attributes, capacity 0, no rows.
		Instances headerOnly = new Instances(instances, 0);
		emptyNode.setInstances(headerOnly);
		double[] zeroDistribution = new double[instances.numClasses()];
		emptyNode.setDistributions(zeroDistribution);
		double[] zeroProbabilities = new double[instances.numClasses()];
		emptyNode.setClassProbs(zeroProbabilities);
		return emptyNode;
	}
	
	/**
	 * Converts a Weka REP tree node, and recursively all of its successors, into a {@link ClassifierTree}.
	 * <p>
	 * For a nominal split attribute one {@link NominalPoint} is created per attribute value;
	 * for any other split attribute a single {@link NumericPoint} holding the split threshold is created.
	 * The node's instances object and its proportion, probability and distribution arrays are handed
	 * to the setters as-is (no copy is made here).
	 *
	 * @param wekaTree the Weka node to convert; may be {@code null}
	 * @return the converted node, or {@code null} when {@code wekaTree} is {@code null}
	 */
	public static ClassifierTree toClassifierTree(WekaREPTree.Tree wekaTree) {
		if(wekaTree == null) {
			return null;
		}
		ClassifierTree node = new ClassifierTree();
		node.setInstances(wekaTree.m_Info);
		node.setSplitAttribute(wekaTree.m_Attribute);
		// A negative attribute index marks a leaf, which has no split points.
		if(wekaTree.m_Attribute >= 0) {
			Attribute splitOn = wekaTree.m_Info.attribute(wekaTree.m_Attribute);
			BinPoints cutPoints = new BinPoints();
			if(!splitOn.isNominal()) {
				NumericPoint threshold = new NumericPoint();
				threshold.setPoint(wekaTree.m_SplitPoint);
				cutPoints.add(threshold);
			}
			else {
				int valueCount = splitOn.numValues();
				for(int v = 0; v < valueCount; v++) {
					NominalPoint category = new NominalPoint();
					category.add(splitOn.value(v));
					cutPoints.add(category);
				}
			}
			node.setSplitPoints(cutPoints);
		}
		node.setInstancesProps(wekaTree.m_Prop);
		node.setClassProbs(wekaTree.m_ClassProbs);
		node.setDistributions(wekaTree.m_Distribution);
		// Convert the successors depth-first, keeping their order.
		boolean hasSuccessors = wekaTree.m_Successors != null && wekaTree.m_Successors.length > 0;
		if(hasSuccessors) {
			int successorCount = wekaTree.m_Successors.length;
			ClassifierTree[] converted = new ClassifierTree[successorCount];
			for(int s = 0; s < successorCount; s++) {
				converted[s] = toClassifierTree(wekaTree.m_Successors[s]);
			}
			node.setChildNodes(converted);
		}
		return node;
	}
	
	/**
	 * Numbers every node of the given tree, starting with 1 at the root.
	 *
	 * @param tree root of the tree to number
	 */
	public static void indexTreeNo(ClassifierTree tree) {
		final int rootNo = 1;
		indexTreeNo(tree, rootNo);
	}
	
	/**
	 * Numbers the given node and all of its descendants in pre-order.
	 * <p>
	 * The node itself receives {@code stratTreeNo}; each following node receives the next free number.
	 * A node that has no UUID yet is also given a random one (32 lower-case hex characters, no dashes).
	 *
	 * @param tree        node to start from
	 * @param stratTreeNo number to assign to {@code tree}
	 * @return the highest number assigned within this subtree
	 */
	public static int indexTreeNo(ClassifierTree tree, int stratTreeNo) {
		tree.setTreeNo(stratTreeNo);
		if(tree.getUUID() == null) {
			String uuid = UUID.randomUUID().toString().toLowerCase().replace("-", "");
			tree.setUUID(uuid);
		}
		int lastAssigned = stratTreeNo;
		ClassifierTree[] children = tree.getChildNodes();
		if(children == null) {
			return lastAssigned;
		}
		for(ClassifierTree child : children) {
			// Each child continues right after the last number used by its previous sibling's subtree.
			lastAssigned = indexTreeNo(child, lastAssigned + 1);
		}
		return lastAssigned;
	}
	
	/**
	 * Finds the node carrying the given tree number, searching depth-first in child order.
	 * <p>
	 * Nodes that have not been numbered yet (tree number is {@code null}) and {@code null}
	 * child entries are skipped instead of causing a {@link NullPointerException}.
	 *
	 * @param tree   root of the (sub)tree to search; may be {@code null}
	 * @param treeNo tree number to look for
	 * @return the first matching node, or {@code null} when no node carries that number
	 */
	public static ClassifierTree getNodeByTreeNo(ClassifierTree tree, int treeNo) {
		if(tree == null) {
			return null;
		}
		// The tree number is a nullable boxed value; it stays null until the tree has been indexed.
		if(tree.getTreeNo() != null && tree.getTreeNo().equals(treeNo)) {
			return tree;
		}
		ClassifierTree[] children = tree.getChildNodes();
		if(children != null) {
			for(ClassifierTree child : children) {
				ClassifierTree match = getNodeByTreeNo(child, treeNo);
				if(match != null) {
					return match;
				}
			}
		}
		return null;
	}
	
	/**
	 * Varargs convenience form of {@link #getAttributes(Instances, List)}.
	 *
	 * @param instances dataset in which the attributes are looked up
	 * @param attrs     names of the attributes to rebuild
	 * @return newly created attributes, in the order of {@code attrs}
	 */
	public static ArrayList<Attribute> getAttributes(Instances instances, String...  attrs ){
		List<String> names = Arrays.asList(attrs);
		return getAttributes(instances, names);
	}
	
	
	/**
	 * Creates fresh {@link Attribute} objects for the named attributes of a dataset.
	 * <p>
	 * An attribute reported as numeric is recreated as a plain numeric attribute of the same name;
	 * any other attribute is recreated with the same name and the same list of values.
	 *
	 * @param instances dataset in which the attributes are looked up by name
	 * @param attrs     names of the attributes to rebuild
	 * @return newly created attributes, in the order of {@code attrs}
	 */
	public static ArrayList<Attribute> getAttributes(Instances instances, List<String> attrs){
		ArrayList<Attribute> rebuilt = new ArrayList<>();
		for(String name : attrs) {
			Attribute source = instances.attribute(name);
			if(source.isNumeric()) {
				rebuilt.add(new Attribute(name));
				continue;
			}
			// Non-numeric: carry over every declared value in its original order.
			List<String> labels = new ArrayList<>();
			int labelCount = source.numValues();
			for(int v = 0; v < labelCount; v++) {
				labels.add(source.value(v));
			}
			rebuilt.add(new Attribute(name, labels));
		}
		return rebuilt;
	}
	
	/**
	 * Lists all attributes of a dataset in index order.
	 * <p>
	 * The returned list holds the dataset's own attribute objects, not copies.
	 *
	 * @param data dataset to read the attributes from
	 * @return a new list containing every attribute of {@code data}
	 */
	public static ArrayList<Attribute> getAttributesByInstances(Instances data) {
		ArrayList<Attribute> all = new ArrayList<>();
		int index = 0;
		while(index < data.numAttributes()) {
			all.add(data.attribute(index));
			index++;
		}
		return all;
	}
	
	/**
	 * Collects the leaves of a tree, optionally followed by a separate "missing value" node.
	 * <p>
	 * A node counts as a leaf when its split attribute index is negative. Leaves are gathered
	 * depth-first in child order; the missing-value node, when given, is appended last.
	 *
	 * @param normalClassifierTree  root of the tree to walk
	 * @param missingClassifierTree extra node to append after the leaves; may be {@code null}
	 * @return a new list with the leaves and, if present, the missing-value node
	 */
	public static List<ClassifierTree> getLeafNodes(ClassifierTree normalClassifierTree, ClassifierTree missingClassifierTree){
		List<ClassifierTree> collected = new ArrayList<>();
		boolean isLeaf = normalClassifierTree.getSplitAttribute() < 0;
		if(!isLeaf) {
			// Only the outermost call carries the missing-value node, so children are walked without it.
			for(ClassifierTree child : normalClassifierTree.getChildNodes()) {
				List<ClassifierTree> below = getLeafNodes(child, null);
				collected.addAll(below);
			}
		}
		else {
			collected.add(normalClassifierTree);
		}
		if(missingClassifierTree != null) {
			collected.add(missingClassifierTree);
		}
		return collected;
	}
	
	/**
	 * Gathers the split points of a tree built on a single feature.
	 * <p>
	 * The split points of the node and of all its descendants are collected. For a numeric
	 * feature they are sorted in ascending order. When a missing-value node is supplied and the
	 * feature is numeric, a {@code NaN} point is appended last to represent the missing bin.
	 *
	 * @param feature               the feature the tree was built on
	 * @param normalClassifierTree  root of the tree to read split points from
	 * @param missingClassifierTree node for missing values; may be {@code null}
	 * @return the collected split points
	 */
	public static BinPoints getBinPointBySingleFeatureTree(Attribute feature, ClassifierTree normalClassifierTree, ClassifierTree missingClassifierTree) {
		BinPoints collected = new BinPoints();
		boolean isSplitNode = normalClassifierTree.getSplitAttribute() >= 0;
		if(isSplitNode) {
			// Own split points first, then those of every subtree.
			collected.addAll(normalClassifierTree.getSplitPoints());
			for(ClassifierTree child : normalClassifierTree.getChildNodes()) {
				collected.addAll(getBinPointBySingleFeatureTree(feature, child, null));
			}
			if(feature.isNumeric()) {
				Comparator<BinPoint<?>> byThreshold = new Comparator<BinPoint<?>>() {
					@Override
					public int compare(BinPoint<?> left, BinPoint<?> right) {
						NumericPoint a = NumericPoint.class.cast(left);
						NumericPoint b = NumericPoint.class.cast(right);
						return a.getPoint().compareTo(b.getPoint());
					}
				};
				Collections.sort(collected, byThreshold);
			}
		}
		// The NaN marker is added after sorting so that it is always the last point.
		if(missingClassifierTree != null && feature.isNumeric()) {
			NumericPoint missingMarker = new NumericPoint();
			missingMarker.setPoint(Double.NaN);
			collected.add(missingMarker);
		}
		return collected;
	}
	
	/**
	 * Computes the gain of a partition given by the leaves of a single-feature tree.
	 * <p>
	 * The class distributions of the leaves form the rows of a table which is passed to
	 * {@code MathUtil.gain} together with {@code MathUtil.priorVal} of the same table.
	 *
	 * @param feature the feature the tree was built on (not used by this method)
	 * @param leafs   leaves of the tree; must contain at least one element
	 * @return the value returned by {@code MathUtil.gain}
	 */
	public static double gainBySingleFeatureTreeLeaf(Attribute feature, List<ClassifierTree> leafs) {
		int classCount = leafs.get(0).getInstances().numClasses();
		int leafCount = leafs.size();
		double[][] distributionTable = new double[leafCount][classCount];
		// Each row is replaced by the leaf's own distribution array.
		int row = 0;
		for(ClassifierTree leaf : leafs) {
			distributionTable[row] = leaf.getDistributions();
			row++;
		}
		double prior = MathUtil.priorVal(distributionTable);
		return MathUtil.gain(distributionTable, prior);
	}
	
	/**
	 * Computes, for each leaf, the share of a node's instances that fall into it.
	 *
	 * @param nodeCount number of instances in the node that is being split
	 * @param leafs     leaves whose instance counts are divided by {@code nodeCount}
	 * @return one proportion per leaf, in list order; {@code null} when {@code leafs} is
	 *         {@code null} or empty, or {@code nodeCount} is not positive
	 */
	public static double[] getInstancesPropsByNodeLeaf(int nodeCount, List<ClassifierTree> leafs) {
		if(leafs == null || leafs.isEmpty() || nodeCount <= 0) {
			return null;
		}
		double[] proportions = new double[leafs.size()];
		int slot = 0;
		for(ClassifierTree leaf : leafs) {
			int leafSize = leaf.getInstances().size();
			proportions[slot++] = (double) leafSize / nodeCount;
		}
		return proportions;
	}
	
	/**
	 * Tells whether a set of bin points contains a bin for missing values.
	 * <p>
	 * A numeric point marks the missing bin when its value is {@code NaN}; a nominal point
	 * marks it when one of its values equals {@code Constant.NOMINAL_NULL}.
	 *
	 * @param binPoints bin points to inspect
	 * @return {@code true} if any point marks a missing-value bin
	 */
	public static boolean hasMissingBin(BinPoints binPoints) {
		for(BinPoint<?> candidate : binPoints) {
			if(candidate instanceof NumericPoint) {
				NumericPoint numeric = (NumericPoint) candidate;
				if(Double.isNaN(numeric.getPoint())) {
					return true;
				}
				continue;
			}
			if(candidate instanceof NominalPoint) {
				ArrayList<String> categories = ((NominalPoint) candidate).getPoint();
				if(categories.contains(Constant.NOMINAL_NULL)) {
					return true;
				}
			}
		}
		return false;
	}
	
	/**
	 * Turns bin points into one display label per bin.
	 * <p>
	 * The kind of the first point decides the format:
	 * <ul>
	 * <li>nominal: one label per point, made of its back-quoted values separated by line breaks;</li>
	 * <li>otherwise: half-open intervals {@code [min, max)} between consecutive points, with
	 * negative and positive infinity added as outer bounds and values shown with two decimals.
	 * When a missing-value bin is present, the last point is taken to be its marker; it is left
	 * out of the intervals and a final {@code NaN} label is appended instead.</li>
	 * </ul>
	 *
	 * @param binPoints bin points to describe; may be {@code null} or empty
	 * @return the labels in bin order; empty when there are no bin points
	 */
	public static List<String> getBinsToLabel(BinPoints binPoints) {
		List<String> labels = new ArrayList<>();
		if(binPoints == null || binPoints.size() == 0) {
			return labels;
		}
		if(binPoints.get(0) instanceof NominalPoint) {
			for(BinPoint<?> binPoint : binPoints) {
				NominalPoint nominal = (NominalPoint) binPoint;
				StringBuilder text = new StringBuilder();
				for(String category : nominal.getPoint()) {
					text.append(Utils.backQuoteChars(category)).append("\n");
				}
				// trim() drops the line break left after the last value.
				labels.add(text.toString().trim());
			}
			return labels;
		}
		boolean withMissingBin = hasMissingBin(binPoints);
		int thresholdCount = withMissingBin ? binPoints.size() - 1 : binPoints.size();
		// Boundaries: -infinity, the real thresholds, +infinity.
		List<BinPoint<?>> bounds = new ArrayList<>();
		NumericPoint lowest = new NumericPoint();
		lowest.setPoint(Double.NEGATIVE_INFINITY);
		bounds.add(lowest);
		bounds.addAll(binPoints.subList(0, thresholdCount));
		NumericPoint highest = new NumericPoint();
		highest.setPoint(Double.POSITIVE_INFINITY);
		bounds.add(highest);
		for(int b = 1; b < bounds.size(); b++) {
			NumericPoint from = (NumericPoint) bounds.get(b - 1);
			NumericPoint to = (NumericPoint) bounds.get(b);
			String interval = "[" + Utils.doubleToString(from.getPoint(), 2) + ", "
					+ Utils.doubleToString(to.getPoint(), 2) + ")";
			labels.add(interval.trim());
		}
		if(withMissingBin) {
			labels.add("["+Utils.doubleToString(Double.NaN, 2)+"]");
		}
		return labels;
	}
	
	/**
	 * Determines the index of the bin an instance's attribute value falls into.
	 * <p>
	 * Nominal attribute: the index of the first point whose value list contains the instance's
	 * string value. Numeric attribute: the index of the half-open interval {@code [min, max)}
	 * formed by consecutive points (with negative and positive infinity as outer bounds) that
	 * contains the value. When the points include a missing-value bin, its marker is taken to be
	 * the last point and a missing value maps to {@code binPoints.size()}.
	 *
	 * @param attribute attribute to read from the instance
	 * @param instance  instance supplying the value
	 * @param binPoints bin points of the attribute
	 * @return the zero-based bin index, or {@code -1} when no bin matches
	 */
	public static int getBinNumForAttribute(Attribute attribute, Instance instance, BinPoints binPoints) {
		if(attribute.isNominal()) {
			String category = instance.stringValue(attribute);
			int position = 0;
			for(BinPoint<?> binPoint : binPoints) {
				List<String> members = ((NominalPoint) binPoint).getPoint();
				if(members.contains(category)) {
					return position;
				}
				position++;
			}
			return -1;
		}
		if(!attribute.isNumeric()) {
			return -1;
		}
		boolean withMissingBin = hasMissingBin(binPoints);
		if(withMissingBin && instance.isMissing(attribute)) {
			return binPoints.size();
		}
		// Leave the missing-bin marker (last point) out of the interval boundaries.
		int thresholdCount = withMissingBin ? binPoints.size() - 1 : binPoints.size();
		List<BinPoint<?>> bounds = new ArrayList<>();
		NumericPoint lowest = new NumericPoint();
		lowest.setPoint(Double.NEGATIVE_INFINITY);
		bounds.add(lowest);
		bounds.addAll(binPoints.subList(0, thresholdCount));
		NumericPoint highest = new NumericPoint();
		highest.setPoint(Double.POSITIVE_INFINITY);
		bounds.add(highest);
		double current = instance.value(attribute);
		for(int b = 0; b + 1 < bounds.size(); b++) {
			NumericPoint from = (NumericPoint) bounds.get(b);
			NumericPoint to = (NumericPoint) bounds.get(b + 1);
			if(current >= from.getPoint() && current < to.getPoint()) {
				return b;
			}
		}
		// No interval matched (for example a NaN value without a missing-value bin).
		return -1;
	}
	
	/**
	 * Renders the class distribution of a tree node as a fixed-width text table.
	 * <p>
	 * The table has one row per class value followed by a {@code total} row. Each row shows the
	 * class label, the class weight from the node's distribution (one decimal) and a percentage
	 * (two decimals). For class rows the percentage is the node's class probability; for the
	 * total row it is the node's total weight divided by {@code sampleCount}. If the node has no
	 * class probabilities, every percentage is 0. Each row ends with the two characters
	 * backslash and {@code n} rather than a real line break.
	 * <p>
	 * The number of class values is read from the class attribute itself, so no pass over the
	 * node's instances is needed. A non-positive {@code sampleCount} yields a total percentage
	 * of 0 instead of an infinite or undefined value.
	 *
	 * @param classifierTree node to describe; its distribution array must not be {@code null}
	 * @param sampleCount    number of samples the total weight is compared against
	 * @return the rendered table
	 */
	public static String getDotTableByTree(ClassifierTree classifierTree, int sampleCount) {
		double[] distributions = classifierTree.getDistributions();
		double[] classProbs = classifierTree.getClassProbs();
		int totalRow = distributions.length;
		double totalWeight = Utils.sum(distributions);
		// Columns: 0 = label, 1 = weight, 2 = ratio.
		Object[][] datas = new Object[totalRow + 1][3];
		// Labels: never write past the class rows, whatever the attribute declares.
		Attribute classAttribute = classifierTree.getInstances().classAttribute();
		int labelCount = Math.min(classAttribute.numValues(), totalRow);
		for(int i = 0; i < labelCount; i++) {
			datas[i][0] = classAttribute.value(i);
		}
		datas[totalRow][0] = "total";
		// Weights
		for(int i = 0; i < totalRow; i++) {
			datas[i][1] = distributions[i];
		}
		datas[totalRow][1] = totalWeight;
		// Ratios
		if(classProbs != null) {
			for(int i = 0; i < classProbs.length; i++) {
				datas[i][2] = classProbs[i];
			}
			datas[totalRow][2] = sampleCount > 0 ? totalWeight / sampleCount : 0.0;
		}
		else {
			for(int i = 0; i < datas.length; i++) {
				datas[i][2] = 0.0;
			}
		}
		// Render
		StringBuilder table = new StringBuilder();
		for(int i = 0; i < datas.length; i++) {
			String label = (i == totalRow) ? "total" : String.valueOf(datas[i][0]);
			table.append("| ").append(format(label, 6, true));

			String weight;
			if(datas[i][1] instanceof Number) {
				weight = Utils.doubleToString(((Number) datas[i][1]).doubleValue(), 1);
			}
			else {
				weight = String.valueOf(datas[i][1]);
			}
			table.append("| ").append(format(weight, 10, true));

			String ratio;
			if(datas[i][2] instanceof Number) {
				ratio = Utils.doubleToString(((Number) datas[i][2]).doubleValue() * 100, 2)+"%";
			}
			else {
				// Only reached when the probability array is shorter than the distribution.
				ratio = String.valueOf(datas[i][2]);
			}
			table.append("| ").append(format(ratio, 6, true));

			table.append("| ");
			// Escaped line break: the literal characters '\' and 'n'.
			table.append("\\n");
		}
		return table.toString();
	}
	
	/**
	 * Pads a string with spaces up to a given width.
	 * <p>
	 * Strings that are already at least {@code width} characters long are returned unchanged;
	 * nothing is ever truncated.
	 *
	 * @param str     text to pad; must not be {@code null}
	 * @param width   minimum length of the result
	 * @param isRigth {@code true} to add the spaces after the text, {@code false} to add them before it
	 * @return the padded text
	 */
	public static String format(String str, int width, boolean isRigth) {
		if(str.length() >= width) {
			return str;
		}
		int padding = width - str.length();
		// Build the result once instead of creating a new String per added space.
		StringBuilder padded = new StringBuilder(width);
		if(isRigth) {
			padded.append(str);
		}
		for(int i = 0; i < padding; i++) {
			padded.append(' ');
		}
		if(!isRigth) {
			padded.append(str);
		}
		return padded.toString();
	}
	
	/**
	 * Builds the rule information for every leaf of a tree, starting with an empty rule path.
	 *
	 * @param sampleCount       total number of samples, used as denominator for the ratios
	 * @param overdueLabelIndex index of the overdue class in the distribution arrays; negative to skip overdue figures
	 * @param sampleOverdueRate overdue rate of the whole sample, used as denominator for the lift
	 * @param classifierTree    root of the tree
	 * @return one entry per reported leaf
	 */
	public static List<RuleInfo> getRuleInfoForleafNodes(int sampleCount, int overdueLabelIndex, double sampleOverdueRate, ClassifierTree classifierTree){
		LinkedList<FeatureBinsLabel> emptyPath = new LinkedList<FeatureBinsLabel>();
		return getRuleInfoForleafNodes(sampleCount, overdueLabelIndex, sampleOverdueRate, classifierTree, emptyPath);
	}
	
	/**
	 * Recursively builds the rule information for every leaf below a node.
	 * <p>
	 * For a split node, one {@link FeatureBinsLabel} is created and reused for all branches: before
	 * descending into a child it is given that branch's bin label and pushed onto {@code binsLabel};
	 * it is popped again after the child returns. For a leaf whose tree number is greater than 1,
	 * a {@link RuleInfo} is produced from the labels currently on the path. Leaves without a tree
	 * number, or with number 1 or less, produce nothing.
	 *
	 * @param sampleCount       total number of samples, used as denominator for the ratios
	 * @param overdueLabelIndex index of the overdue class in the distribution arrays; negative to skip overdue figures
	 * @param sampleOverdueRate overdue rate of the whole sample, used as denominator for the lift
	 * @param classifierTree    node to process
	 * @param binsLabel         path of split conditions from the root to this node; modified during the walk and restored afterwards
	 * @return one entry per reported leaf below (or at) this node
	 */
	private static List<RuleInfo> getRuleInfoForleafNodes(int sampleCount, int overdueLabelIndex, double sampleOverdueRate, ClassifierTree classifierTree, LinkedList<FeatureBinsLabel> binsLabel){
		List<RuleInfo> collected = new ArrayList<>();
		if(classifierTree.getSplitAttribute() >= 0) {
			// Split node: describe the split once, then walk each branch with its own bin label.
			Attribute splitOn = classifierTree.getInstances().attribute(classifierTree.getSplitAttribute());
			FeatureBinsLabel step = new FeatureBinsLabel();
			step.setFeatureName(splitOn.name());
			List<String> branchLabels = getBinsToLabel(classifierTree.getSplitPoints());
			if(splitOn.isNominal()) {
				step.setFeatureType(Constant.NOMINAL);
			}
			else {
				step.setFeatureType(Constant.NUMERIC);
			}
			ClassifierTree[] children = classifierTree.getChildNodes();
			for(int branch = 0; branch < children.length; branch++) {
				step.setBinsLabel(branchLabels.get(branch));
				binsLabel.add(step);
				collected.addAll(getRuleInfoForleafNodes(sampleCount, overdueLabelIndex, sampleOverdueRate, children[branch], binsLabel));
				binsLabel.removeLast();
			}
			return collected;
		}
		// Leaf: report it unless it is un-numbered or is the root itself.
		if(classifierTree.getTreeNo() == null || classifierTree.getTreeNo() <= 1) {
			return collected;
		}
		RuleInfo ruleInfo = new RuleInfo();
		ruleInfo.setNodeNumber(classifierTree.getTreeNo().toString());
		int leafSize = classifierTree.getInstances().size();
		ruleInfo.setSampleCount(leafSize);
		if(sampleCount > 0) {
			ruleInfo.setHitRatio((double) leafSize / sampleCount);
			if(overdueLabelIndex >= 0) {
				// NOTE(review): the overdue weight of the leaf is divided by the whole sample count,
				// not by the leaf size - confirm that this is the intended definition.
				double overdueRate = classifierTree.getDistributions()[overdueLabelIndex] / sampleCount;
				ruleInfo.setOverRate(overdueRate);
				if(sampleOverdueRate > 0) {
					ruleInfo.setLift(overdueRate / sampleOverdueRate);
				}
			}
		}
		// Rule chain: one textual condition per split on the path from the root.
		List<String> ruleList = new ArrayList<>();
		for(FeatureBinsLabel pathStep : binsLabel) {
			ruleList.add(toRuleExpression(pathStep));
		}
		ruleInfo.setRuleList(ruleList);
		collected.add(ruleInfo);
		return collected;
	}

	/**
	 * Renders one split condition as text.
	 * <p>
	 * Nominal features become {@code name in [a,b]}; a numeric label containing {@code NaN}
	 * becomes {@code name = NaN}; any other numeric label {@code [min, max)} becomes
	 * {@code name >= min && name < max}.
	 *
	 * @param label feature name, feature type and bin label of the split branch
	 * @return the textual condition
	 */
	private static String toRuleExpression(FeatureBinsLabel label) {
		String featureName = label.getFeatureName();
		String binText = label.getBinsLabel();
		if(Constant.NOMINAL.equalsIgnoreCase(label.getFeatureType())) {
			return featureName + " in [" + binText.trim().replace("\n", ",") + "]";
		}
		if(binText.contains("NaN")) {
			return featureName + " = NaN";
		}
		// Strip the interval brackets, then split into lower and upper bound.
		String[] range = binText.replace("[", "").replace(")", "").trim().split(",");
		return featureName + " >= " + range[0].trim() + " && " + featureName + " < " + range[1].trim();
	}
	
}
